[Docs] Add JaxTrainer API Overview to Ray Docs#57182
[Docs] Add JaxTrainer API Overview to Ray Docs#57182matthewdeng merged 19 commits intoray-project:masterfrom
Conversation
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com> Finish jax guide Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
There was a problem hiding this comment.
Code Review
This pull request adds a new documentation page for JaxTrainer, which is a great addition. The guide provides a good overview of the API and how to use it with TPUs. I've made a few suggestions to improve the document's correctness and consistency, mainly around fixing some formatting issues, correcting a version number, and aligning the code examples with the recommended public APIs. Addressing these points will make the guide clearer and more accurate for users.
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com>
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
|
|
||
| JaxTrainer API | ||
| -------------- | ||
| The `JaxTrainer` is the core component for orchestrating distributed JAX training in Ray Train with TPUs. |
There was a problem hiding this comment.
We should add these to the API references so we can link them here.
There was a problem hiding this comment.
I added it in f634b93., I only linked it in two places though not every time I used JaxTrainer since it seemed excessive.
Co-authored-by: matthewdeng <matthew.j.deng@gmail.com> Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com>
Co-authored-by: matthewdeng <matthew.j.deng@gmail.com> Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com>
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
## Why are these changes needed? With #55207 Ray Train now has support for training functions with a JAX backend through the new `JaxTrainer` API. This guide provides a short overview of the API, how to configure with TPUs, and how to edit a JAX script to use Ray Train. TODO: I will link a longer e2e guide with KubeRay, MaxText, and the JaxTrainer on TPUs in GKE --------- Signed-off-by: Ryan O'Leary <ryanaoleary@google.com> Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: matthewdeng <matthew.j.deng@gmail.com> Co-authored-by: matthewdeng <matt@anyscale.com> Signed-off-by: elliot-barn <elliot.barnwell@anyscale.com>
## Why are these changes needed? With ray-project#55207 Ray Train now has support for training functions with a JAX backend through the new `JaxTrainer` API. This guide provides a short overview of the API, how to configure with TPUs, and how to edit a JAX script to use Ray Train. TODO: I will link a longer e2e guide with KubeRay, MaxText, and the JaxTrainer on TPUs in GKE --------- Signed-off-by: Ryan O'Leary <ryanaoleary@google.com> Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: matthewdeng <matthew.j.deng@gmail.com> Co-authored-by: matthewdeng <matt@anyscale.com>
## Why are these changes needed? With ray-project#55207 Ray Train now has support for training functions with a JAX backend through the new `JaxTrainer` API. This guide provides a short overview of the API, how to configure with TPUs, and how to edit a JAX script to use Ray Train. TODO: I will link a longer e2e guide with KubeRay, MaxText, and the JaxTrainer on TPUs in GKE --------- Signed-off-by: Ryan O'Leary <ryanaoleary@google.com> Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: matthewdeng <matthew.j.deng@gmail.com> Co-authored-by: matthewdeng <matt@anyscale.com> Signed-off-by: Aydin Abiar <aydin@anyscale.com>
## Why are these changes needed? With ray-project#55207 Ray Train now has support for training functions with a JAX backend through the new `JaxTrainer` API. This guide provides a short overview of the API, how to configure with TPUs, and how to edit a JAX script to use Ray Train. TODO: I will link a longer e2e guide with KubeRay, MaxText, and the JaxTrainer on TPUs in GKE --------- Signed-off-by: Ryan O'Leary <ryanaoleary@google.com> Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: matthewdeng <matthew.j.deng@gmail.com> Co-authored-by: matthewdeng <matt@anyscale.com> Signed-off-by: Future-Outlier <eric901201@gmail.com>
## Why are these changes needed? With ray-project#55207 Ray Train now has support for training functions with a JAX backend through the new `JaxTrainer` API. This guide provides a short overview of the API, how to configure with TPUs, and how to edit a JAX script to use Ray Train. TODO: I will link a longer e2e guide with KubeRay, MaxText, and the JaxTrainer on TPUs in GKE --------- Signed-off-by: Ryan O'Leary <ryanaoleary@google.com> Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: matthewdeng <matthew.j.deng@gmail.com> Co-authored-by: matthewdeng <matt@anyscale.com>
Why are these changes needed?
With #55207 Ray Train now has support for training functions with a JAX backend through the new
JaxTrainerAPI. This guide provides a short overview of the API, how to configure with TPUs, and how to edit a JAX script to use Ray Train.TODO: I will link a longer e2e guide with KubeRay, MaxText, and the JaxTrainer on TPUs in GKE
Related issue number
Checks
git commit -s) in this PR.method in Tune, I've added it in
doc/source/tune/api/under thecorresponding
.rstfile.